# #####################################
# #                                   #
# #    NOT FULLY FUNCTIONAL CODE      #
# #                                   #
# #####################################
#
# This file provides snippets of the code used to create the EPD model
# described in the paper.  NOTE: this is not code that is used to train the
# model--this is simply the code needed to construct the model that would then
# be ready for training.  And even then, the code is not likely 100% complete.
# But these are exact code snippets taken from code used to generate the models
# in the paper.  Some knowledge of Keras and Tensorflow is likely necessary to
# understand these code snippets.
#
# The create_model function at the very end is what ultimately creates a model
# that is ready for training.  It expects to be passed in the input and label
# shapes, which theoretically, you'd have if you had a dataset prepped for
# training.  The shape for training data used in the paper was (batch, height,
# width, channels).  The shape for label data was the same, with channels = 1.
#
# Note that this code was copy/pasted out of multiple files from our source
# code into this file and this file was never explicitly run.  Bugs from that
# process may have been introduced.

#
# Imports
#

import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple

from ml_collections import config_dict
import tensorflow.compat.v2 as tf

#
# Channel Encoding Model
#

"""Generates a channel encoding model.

Replaces a single channel of categorical data in an input tensor with a set of
channels that represent a learned embedding of that categorical data.  The
replaced channel must be the last channel (simply permute the input before this
model is applied if the categorical channel is not the last), and the new
channels added are added as the last channels in the output.

The slice of the input tensor that contains the categorical data is assumed to
be of a floating-point type, but with values that correspond to whole-valued
integers that range between 0.0 and float(cfg.num_values - 1).

The input to the model must have the shape (b, S, c) where b is the batch
dimension, S is the set of spatial/temporal dimensions and c is the channel
dimension.

For example, if the input was shaped (b, t, h, w, c=5), and cfg.size=5, then the
emitted shape would be (b, t, h, w, c=5-1+5=9) where channels 0, 1, 2 and 3 were
the original channels, and channels 4 through 8 would be the newly added
embedding channels.  The original channel 4, the categorical channel, is
removed.
"""

class RemoveLastChannel(tf.keras.layers.Layer):
  """Keras layer that drops the final channel of its input.

  The layer takes no configuration of its own; it slices away the last index
  of the last axis and leaves every other axis untouched.
  """

  def call(self, network):
    """Returns `network` without the last entry along its last axis."""
    all_but_last = network[..., :-1]
    return all_but_last


class LastChannelOneHot(tf.keras.layers.Layer):
  """One-hot encodes the final channel of the input tensor.

  The final channel is expected to hold categorical values that range between
  0 and some max value.  The values may be stored as floats; they are cast to
  int32 before being encoded.
  """

  def __init__(self, num_values: int = 10, **kwargs):
    """Initializes the layer.

    Args:
      num_values: The depth of the one-hot encoding that is produced.
      **kwargs: Forwarded unchanged to the Keras Layer constructor.
    """
    super().__init__(**kwargs)
    self.num_values = num_values

  def get_config(self):
    """Returns the base layer config extended with `num_values`."""
    base_config = super().get_config()
    return {**base_config, 'num_values': self.num_values}

  def call(self, network):
    """Returns a float32 one-hot encoding of the last channel of `network`."""
    # Isolate the categorical channel (this drops the channel axis) and turn
    # its values into integer category ids.
    category_ids = tf.cast(network[..., -1], dtype=tf.int32)

    # The one-hot encoding adds a new trailing axis of size num_values.
    return tf.one_hot(
        category_ids,
        depth=self.num_values,
        on_value=1.0,
        off_value=0.0,
        dtype=tf.float32)


class ChannelEncodingModel():
  """Builds the layers that replace a categorical channel with an embedding.

  The last channel of the input is one-hot encoded, passed through a bias-free
  Conv2D to produce `cfg.size` embedding channels, and those channels are
  concatenated onto the input with its original last channel removed.
  """

  @classmethod
  def get_default_config(cls)-> config_dict.ConfigDict:
    """Gets the (locked) default config for the encoding."""
    cfg = config_dict.ConfigDict()

    # The number of discrete categories in the input, and thus, the number of
    # layers in the one-hot-encoding.  NOTE: large values here will require lots
    # of memory.  This is not a hyperparameter, its value needs to be the
    # maximum value seen in the categorical data.
    #
    # WARNING: All values seen in the data that are greater than this value will
    # be treated identically in the embedding.  There is no verification that
    # such values do not exist.
    cfg.num_values = 40

    # The number of channels in the embedding, and thus the number of embedding
    # channels added to the input.
    cfg.size = 5

    # The size of the convolutional kernel used when making the embedding.  A
    # value of 1 results in the most interpretable results, but larger values
    # may aid in overall performance.
    cfg.kernel_size = 1

    # If non-empty, the name of the activation unit to add after the encoding.
    cfg.activation = ''

    cfg.lock()

    return cfg

  def get_custom_objects(self) -> Dict[str, Any]:
    """Returns the custom Keras layers used by this model, keyed by name."""
    return {'LastChannelOneHot': LastChannelOneHot,
            'RemoveLastChannel': RemoveLastChannel}

  def verify_input_shapes(
      self, inputs: Sequence[tf.Tensor], cfg: Optional[config_dict.ConfigDict]):
    """Verifies there is exactly one input and that it has rank 4 or more.

    Args:
      inputs: The inputs to the model.
      cfg: The model config.  Not used by this check.

    Raises:
      ValueError: If there is not exactly one input, or its rank is below 4.
    """
    if len(inputs) != 1:
      raise ValueError(f'Expecting a single input.  Found: {len(inputs)}')
    # The embedding is produced by a Conv2D over the one-hot encoding, so the
    # input needs a batch axis, two spatial axes and a channel axis at minimum.
    if len(inputs[0].shape) <= 3:
      raise ValueError(
          'Expecting a batch dimension, at least 2 spatial dimensions, and '
          f'one channel dimension.  Found shape: {inputs[0].shape}')

  def create(
      self,
      inputs: Sequence[tf.Tensor],
      cfg: config_dict.ConfigDict,
      name: str) -> List[tf.Tensor]:
    """Creates the model.

    Args:
      inputs: A sequence holding the single input tensor, whose last channel
        is the categorical channel.
      cfg: The config, with the fields set up by get_default_config().
      name: A prefix to give the names of the layers added.

    Returns:
      A single-element list holding the input with its last channel replaced
      by cfg.size embedding channels.

    Raises:
      ValueError: If the inputs fail verify_input_shapes().
    """
    self.verify_input_shapes(inputs, cfg)
    input_network = inputs[0]

    # Create the embedding: one-hot encode the categorical channel and let a
    # bias-free convolution map the one-hot channels to the embedding channels.
    network = input_network
    network = LastChannelOneHot(cfg.num_values, name=name+'/one_hot')(network)
    network = tf.keras.layers.Conv2D(
        filters=cfg.size,
        kernel_size=cfg.kernel_size,
        # An empty activation name means no activation.
        activation=cfg.activation if cfg.activation else None,
        strides=1,
        padding='same',
        use_bias=False,
        dilation_rate=1,
        name=name + '/Conv2D')(network)

    # Remove the categorical slice from the original input and append the
    # embedding channels in its place.
    input_network = RemoveLastChannel(
        name=name + '/remove_categorical_slice')(input_network)
    network = tf.keras.layers.Concatenate(
        axis=-1,
        name=name + '/concatenate')([input_network, network])
    return [network]


#
# Code that adds the EPD-specific layers to the model.
#

"""Generates an (e)ncoder-(p)rocessor-(d)ecoder model.

There is an encoder stage, which is just a series of convolutional blocks, a
processor stage, which is a repeated set of stacked convolutions with residual
links, and a decoder stage, which is also just a series of convolutional blocks.

This structure encompasses the network architecture described in:

Learning General-purpose CNN-based simulators for astrophysical turbulence.
Sanchez-Gonzalez, et al., 2021.
https://simdl.github.io/files/26.pdf

The add_predecoder_phase(...) function allows for the injection of a callable
that will transform the output of the processor before sending it to the
decoder.  The predecoder phase is allowed to return additional state (e.g., the
state of an RNN), which will be returned as additional outputs to the model.
"""


def get_regularization_params(
    l1: float, l2: float, kind: str) -> Dict[str, Any]:
  """Returns the regularization parameters.

  These parameters can be used to specify regularization in Keras layers, e.g.,
  for the Conv2D layer, by unpacking the returned dict as keyword arguments.

  Args:
    l1: The amount of l1 regularization to use.
    l2: The amount of l2 regularization to use.
    kind: The type of regularization.
      'kernel':  Only kernel weights will be regularized.
      'bias': Only the bias weights will be regularized.
      'activity': Only the output of the layer will be regularized.
      'all': Regularize kernel, bias and activity.
      'weights': Regularize both kernel and bias.

  Returns:
    A dict mapping Keras layer argument names ('kernel_regularizer',
    'bias_regularizer', 'activity_regularizer') to regularizer instances.  The
    dict is empty when both l1 and l2 are zero.

  Raises:
    ValueError: If regularization is requested and kind is not one of the
      values listed above.
  """
  # Determine which regularizer to use.
  layer_params = {}
  if not l1 and not l2:
    # No regularization requested; kind is deliberately not inspected.
    return layer_params
  if l1 and not l2:
    which_reg = tf.keras.regularizers.l1
    reg_params = {'l1': l1}
  elif not l1 and l2:
    which_reg = tf.keras.regularizers.l2
    reg_params = {'l2': l2}
  else:
    which_reg = tf.keras.regularizers.l1_l2
    reg_params = {'l1': l1, 'l2': l2}

  # Determine the parameters for the regularizer.  Each target gets its own
  # regularizer instance.
  if kind not in ('kernel', 'bias', 'activity', 'all', 'weights'):
    raise ValueError(f'Unexpected kind: {kind}')
  if kind in ('kernel', 'all', 'weights'):
    layer_params['kernel_regularizer'] = which_reg(**reg_params)
  # These two must be membership tests: comparing the string to a list with ==
  # is always False and would silently drop the regularizer.
  if kind in ('bias', 'all', 'weights'):
    layer_params['bias_regularizer'] = which_reg(**reg_params)
  if kind in ('activity', 'all'):
    layer_params['activity_regularizer'] = which_reg(**reg_params)

  return layer_params


def get_short_skip_link_2d(
    use_bias: bool,
    start: tf.Tensor,
    end: tf.Tensor,
    name: str) -> tf.Tensor:
  """Connects start to end with a residual (skip) link.

  start is passed through a 1x1 Conv2D so that its channel count matches that
  of end, and the result is added element-wise to end.  Both tensors must be
  rank 4 (bhwc).

  Args:
    use_bias: the use_bias parameter for the Conv2D on the skip path.
    start: the tensor the skip link originates from.
    end: the tensor the skip link is added to.
    name: A prefix to give the names of the layers added.

  Returns:
    The sum of end and the channel-matched start.

  Raises:
    ValueError: If start or end is not rank 4.
  """
  rank_checks = (
      (start, 'Only supporting start shapes of rank 4 currently.'),
      (end, 'Only supporting end shapes of rank 4 currently.'),
  )
  for tensor, message in rank_checks:
    if len(tensor.shape) != 4:
      raise ValueError(message)

  # Project start onto as many channels as end has so the two can be summed.
  projected_start = tf.keras.layers.Conv2D(
      filters=end.shape[-1],
      kernel_size=1,
      strides=1,
      padding='same',
      use_bias=use_bias,
      name=name + '/CONV2D')(start)
  return tf.keras.layers.Add(name=name + '/ADD')([projected_start, end])


def verify_singular_square_inputs(inputs: Sequence[tf.Tensor]):
  """Checks that inputs holds exactly one rank-4 (bhwc) tensor with h == w.

  Args:
    inputs: The inputs to check.

  Raises:
    ValueError if there is not exactly one input, if that input is not rank 4,
    or if its two spatial dimensions differ.
  """
  num_inputs = len(inputs)
  if num_inputs != 1:
    raise ValueError(
        f'There must be exactly one input for this model: {num_inputs}.')

  shape = inputs[0].shape
  rank = len(shape)
  if rank != 4:
    raise ValueError(f'There must be exactly 4 dimensions in the data (bhwc): '
                     f'{rank}')

  # Axes 1 and 2 are the spatial (height, width) axes of a bhwc tensor.
  height, width = shape[1], shape[2]
  if height != width:
    raise ValueError(f'Expecting square spatial dimensions: {shape}')


class EPD2DModel():
  """An encode-processor-decoder (EPD) model built from 2D convolutions.

  create() chains optional input batch normalization, an encoder, a processor
  made of blocks with skip links, a decoder and a final 1x1 convolution (the
  "tail").  Every convolution uses stride 1 and 'same' padding.
  """

  @classmethod
  def _get_encoder_config(cls) -> config_dict.ConfigDict:
    """Gets the config for the encoder."""
    cfg = config_dict.ConfigDict()

    # If true, perform a batch norm after each convolution.
    cfg.batch_norm = False

    # The name of the tf.keras.layers activation class to insert between
    # consecutive convolutions (i.e., before every convolution but the first).
    # If empty, no activations are added.
    cfg.act = 'ReLU'

    # The number of filters for each convolution in the encoder (and thus, the
    # number of convolutions).
    cfg.filters = (32, 32, 3)

    # The size of the square 2D kernel to use.
    cfg.kernel = 3

    return cfg

  @classmethod
  def _get_processor_config(cls) -> config_dict.ConfigDict:
    """Gets the config for the processor."""
    cfg = config_dict.ConfigDict()

    # The number of blocks in the processor.
    cfg.blocks = 4

    # If true, perform a batch norm after each convolution in a block.
    cfg.batch_norm = False

    # The name of the tf.keras.layers activation class to insert between
    # consecutive convolutions within a block (i.e., before every convolution
    # but the first).  If empty, no activations are added.
    cfg.act = 'ReLU'

    # The number of filters to use in each convolution block of the processor.
    cfg.filters = (32, 32, 32, 32, 32)

    # The amount of dilation in each of the convolutions.  Must have same number
    # of elements as filters.
    cfg.dilations = (1, 2, 4, 2, 1)

    # The size of the square 2D kernel to use.
    cfg.kernel = 3

    return cfg

  @classmethod
  def _get_decoder_config(cls) -> config_dict.ConfigDict:
    """Gets the config for the decoder."""
    cfg = config_dict.ConfigDict()

    # If true, perform a batch norm after each convolution.
    cfg.batch_norm = False

    # The name of the tf.keras.layers activation class to insert between
    # consecutive convolutions (i.e., before every convolution but the first).
    # If empty, no activations are added.
    cfg.act = 'ReLU'

    # The number of filters to use in each convolution in the decoder (and thus,
    # the number of convolutions in the decoder).
    cfg.filters = (32, 64, 32)

    # The size of the square 2D kernel to use.
    cfg.kernel = 3

    return cfg

  @classmethod
  def _get_reg_config(cls) -> config_dict.ConfigDict:
    """Gets the config for regularization."""
    cfg = config_dict.ConfigDict()

    # If non-zero, the amount of l1 and l2 regularization to add to Keras
    # layers.  If both of these are zero, no regularization occurs.
    cfg.l1 = 0.0
    cfg.l2 = 0.0

    # Which type of regularization to add.  Valid values are: 'kernel', 'bias',
    #  'activity', 'all' or 'weights' (both kernel and bias, but not activity).
    #  See https://keras.io/api/layers/regularizers/ for details.
    cfg.kind = 'all'

    return cfg

  @classmethod
  def get_default_config(cls)-> config_dict.ConfigDict:
    """Gets the (locked) default config for the whole model."""
    cfg = config_dict.ConfigDict()

    # The settings for the encoder.
    cfg.encoder = cls._get_encoder_config()

    # The settings for the processor.
    cfg.processor = cls._get_processor_config()

    # The settings for the decoder.
    cfg.decoder = cls._get_decoder_config()

    # If true, all convolutions will train a bias.
    cfg.use_bias = True

    # If true, apply batch norm on the input.
    cfg.input_batch_norm = False

    # The fraction of parameters to dropout after every convolution in the
    # encoder, decoder and processor (does not affect the final convolution).
    # Leave as 0.0 for no dropout.
    cfg.dropout = 0.0

    # The name of the tf.keras.layers activation class applied to the decoder's
    # output, just before the final convolution that reduces it to the desired
    # number of outputs.  If empty, no activation is added.
    cfg.final_act = 'ReLU'

    # Specifies the type of loss function regularization to add to the various
    # Conv2D layers in the network.
    cfg.reg = cls._get_reg_config()

    # The number of output channels to have (i.e., the number of filters in the
    # model's final 2D convolution).
    cfg.num_output_channels = 1

    cfg.lock()

    return cfg

  def get_custom_objects(self) -> Dict[str, Any]:
    """Returns the custom Keras objects used by this model (there are none)."""
    # This model uses no custom objects.
    return {}

  def _add_encoder(
      self,
      cfg: config_dict.ConfigDict,
      network: tf.Tensor,
      name: str) -> tf.Tensor:
    """Adds the encoder stage.

    Args:
      cfg: The full model config (see get_default_config()).
      network: The tensor to append the encoder to.
      name: A prefix to give the names of the layers added.

    Returns:
      The output of the last layer of the encoder.
    """
    for i, curr_filter in enumerate(cfg.encoder.filters):
      # Activations sit between convolutions, so none precedes the first one.
      if cfg.encoder.act and i != 0:
        network = getattr(
            tf.keras.layers, cfg.encoder.act)(name=f'{name}_{i}/Act')(network)
      network = tf.keras.layers.Conv2D(
          filters=curr_filter,
          kernel_size=cfg.encoder.kernel,
          strides=1,
          padding='same',
          use_bias=cfg.use_bias,
          dilation_rate=1,  # no dilation ever in the encoder
          # NOTE: unlike the decoder, there is no '/Conv2D' suffix here.  The
          # name is left as-is so that existing layer names do not change.
          name=f'{name}_{i}',
          **get_regularization_params(
              cfg.reg.l1, cfg.reg.l2, cfg.reg.kind))(network)
      if cfg.encoder.batch_norm:
        network = tf.keras.layers.BatchNormalization(
            name=f'{name}_{i}/BN')(network)
      # Same dropout test as the decoder and processor stages.
      if cfg.dropout > 0.0:
        network = tf.keras.layers.Dropout(
            cfg.dropout, name=f'{name}_{i}/Dropout')(network)
    return network

  def _add_decoder(
      self,
      cfg: config_dict.ConfigDict,
      network: tf.Tensor,
      name: str) -> tf.Tensor:
    """Adds the decoder stage.

    Args:
      cfg: The full model config (see get_default_config()).
      network: The tensor to append the decoder to.
      name: A prefix to give the names of the layers added.

    Returns:
      The output of the last layer of the decoder.
    """
    for i, curr_filter in enumerate(cfg.decoder.filters):
      # Activations sit between convolutions, so none precedes the first one.
      # The activation comes from the decoder's own config, not the encoder's.
      if cfg.decoder.act and i != 0:
        network = getattr(
            tf.keras.layers, cfg.decoder.act)(name=f'{name}_{i}/Act')(network)
      network = tf.keras.layers.Conv2D(
          filters=curr_filter,
          kernel_size=cfg.decoder.kernel,
          strides=1,
          padding='same',
          use_bias=cfg.use_bias,
          dilation_rate=1,  # no dilation ever in the decoder
          name=f'{name}_{i}/Conv2D',
          **get_regularization_params(
              cfg.reg.l1, cfg.reg.l2, cfg.reg.kind))(network)
      if cfg.decoder.batch_norm:
        network = tf.keras.layers.BatchNormalization(name=f'{name}_{i}/BN')(
            network)
      if cfg.dropout > 0.0:
        network = tf.keras.layers.Dropout(
            cfg.dropout, name=f'{name}_{i}/Dropout')(network)
    return network

  def _add_processor_block(
      self,
      cfg: config_dict.ConfigDict,
      network: tf.Tensor,
      name: str) -> tf.Tensor:
    """Adds a single processor block to the end of network.

    A block is a stack of (possibly dilated) convolutions whose output is
    summed with the block's input through a skip link.

    Args:
      cfg: The full model config (see get_default_config()).
      network: The tensor to append the block to.
      name: A prefix to give the names of the layers added.

    Returns:
      The output of the block, after the skip link has been added.
    """
    # Record the current network as the start of the skip link.
    skip_start = network

    # Add the convolution layers.
    for i, (curr_filter, curr_dilation) in enumerate(zip(
        cfg.processor.filters, cfg.processor.dilations)):
      # Activations sit between convolutions, so none precedes the first one.
      if cfg.processor.act and i != 0:
        network = getattr(
            tf.keras.layers, cfg.processor.act)(
                name=name + f'/Act_{i}')(network)
      network = tf.keras.layers.Conv2D(
          filters=curr_filter,
          kernel_size=cfg.processor.kernel,
          strides=1,
          padding='same',
          use_bias=cfg.use_bias,
          dilation_rate=curr_dilation,
          name=name + f'/Conv2D_{i}',
          **get_regularization_params(
              cfg.reg.l1, cfg.reg.l2, cfg.reg.kind))(network)
      if cfg.processor.batch_norm:
        network = tf.keras.layers.BatchNormalization(name=name + f'/BN_{i}')(
            network)
      if cfg.dropout > 0.0:
        network = tf.keras.layers.Dropout(
            cfg.dropout, name=f'{name}/Dropout_{i}')(network)

    # Add the skip link.
    network = get_short_skip_link_2d(
        cfg.use_bias, skip_start, network, name + '/skip')

    return network

  def _add_processor(
      self,
      cfg: config_dict.ConfigDict,
      network: tf.Tensor,
      name: str) -> tf.Tensor:
    """Adds the processor stage: cfg.processor.blocks blocks in sequence."""
    for i in range(cfg.processor.blocks):
      network = self._add_processor_block(cfg, network, name + f'/block_{i}')
    return network

  def _add_tail(
      self,
      cfg: config_dict.ConfigDict,
      network: tf.Tensor,
      name: str) -> tf.Tensor:
    """Adds the final activation (if any) followed by the final convolution.

    Args:
      cfg: The full model config (see get_default_config()).
      network: The tensor to append the tail to.
      name: A prefix to give the names of the layers added.

    Returns:
      The output of a 1x1 convolution with cfg.num_output_channels filters.
    """
    if cfg.final_act:
      network = getattr(
          tf.keras.layers, cfg.final_act)(name=name + '/final_act')(network)
    network = tf.keras.layers.Conv2D(
        filters=cfg.num_output_channels,
        kernel_size=1,
        strides=1,
        padding='same',
        name=name + '/final_conv2d',
        use_bias=cfg.use_bias,
        **get_regularization_params(
            cfg.reg.l1, cfg.reg.l2, cfg.reg.kind))(network)
    return network

  def _verify_cfg(self, cfg: config_dict.ConfigDict):
    """Raises ValueError if the processor filters and dilations differ in length."""
    # The processor zips these together, so a mismatch would silently drop
    # the extra entries.
    if len(cfg.processor.filters) != len(cfg.processor.dilations):
      raise ValueError('The number of filters and dilations must be the same.')

  def _add_input_processing(
      self,
      cfg: config_dict.ConfigDict, network: tf.Tensor,
      name: str) -> tf.Tensor:
    """Adds any processing to the input the model needs.

    Currently this is only a batch normalization, added when
    cfg.input_batch_norm is true.
    """
    if cfg.input_batch_norm:
      network = tf.keras.layers.BatchNormalization(
          name=name + '/input_bn')(network)
    return network

  def verify_input_shapes(
      self, inputs: Sequence[tf.Tensor], cfg: Optional[config_dict.ConfigDict]):
    """Verifies there is a single 4D (bhwc) input with square h and w.

    Args:
      inputs: The inputs to the model.
      cfg: The model config.  Not used by this check.

    Raises:
      ValueError: If the conditions are not met.
    """
    verify_singular_square_inputs(inputs)

  def create(
      self,
      inputs: Sequence[tf.Tensor],
      cfg: config_dict.ConfigDict,
      name: str) -> List[tf.Tensor]:
    """Creates the encoder-processor-decoder model.

    Args:
      inputs: A sequence holding the single 4D (bhwc) input tensor.
      cfg: The full model config (see get_default_config()).
      name: A prefix to give the names of the layers added.

    Returns:
      A single-element list holding the output of the model's tail.

    Raises:
      ValueError: If cfg or inputs fail verification.
    """
    # Verify the input.
    self._verify_cfg(cfg)
    self.verify_input_shapes(inputs, cfg)

    # Create the model.
    network = inputs[0]
    network = self._add_input_processing(
        cfg, network, name + '/InputProcessing')
    network = self._add_encoder(cfg, network, name + '/Encoder')
    network = self._add_processor(cfg, network, name + '/Processor')
    network = self._add_decoder(cfg, network, name + '/Decoder')
    network = self._add_tail(cfg, network, name + '/Tail')

    return [network]


#
# Code to create the EPD model used in the paper.
#


def get_config() -> config_dict.ConfigDict:
  """Builds the locked, top-level config for the model."""
  config = config_dict.ConfigDict()

  # Whether to replace the categorical input with an embedding.
  config.use_embedding = False

  # Settings for the embedding layer.  Ignored when use_embedding is false.
  config.embedding = ChannelEncodingModel.get_default_config()

  # Settings for the EPD model.
  config.epd_2d = EPD2DModel.get_default_config()

  config.lock()
  return config


def get_custom_objects() -> Dict[str, Any]:
  """Collects the custom objects of every sub-model into a single dict.

  Entries from the EPD model take precedence over same-named entries from the
  channel encoding model.
  """
  return {
      **ChannelEncodingModel().get_custom_objects(),
      **EPD2DModel().get_custom_objects(),
  }


def create_head(
    input_networks: List[tf.Tensor],
    cfg: config_dict.ConfigDict,
    expected_shape_wo_batch: Tuple[int, int, int],
) -> List[tf.Tensor]:
  """Creates the head of the model.

  The single input is pinned to the expected shape, optionally passed through
  the channel embedding, and then through the EPD model.

  Args:
    input_networks: A list holding the single input tensor.
    cfg: The config, with the fields set up by get_config().
    expected_shape_wo_batch: The (h, w, c) shape of the input, without the
      batch dimension.

  Returns:
    A single-element list holding the output of the EPD model.

  Raises:
    ValueError: If expected_shape_wo_batch does not have 3 entries or there is
      not exactly one input.
  """
  if len(expected_shape_wo_batch) != 3:
    raise ValueError('Expected shape should be (h, w, c) (no batch)')

  # Defines the base name for the model.
  name = 'epd'

  if len(input_networks) != 1:
    raise ValueError('Expecting a single input.')
  network = input_networks[0]

  # Prepend an unknown batch dimension.  concatenate() accepts a tuple, a list
  # or a TensorShape, so callers may pass any of those.
  expected_shape = tf.TensorShape([None]).concatenate(expected_shape_wo_batch)
  network = tf.ensure_shape(network, expected_shape)
  logging.info('The input shape was: %s', network.shape)

  # Add the embedding layer, if needed.
  if cfg.use_embedding:
    channel_encoding_ml_model = ChannelEncodingModel()
    # NOTE: this name does not carry the base-name prefix; it is left as-is so
    # that existing layer names do not change.
    [network] = channel_encoding_ml_model.create([network],
                                                 cfg.embedding,
                                                 name='embedding')

  # Add the EPD model.  EPD2DModel is defined in this file, so it is
  # referenced directly rather than through a module prefix.
  epd_2d_ml_model = EPD2DModel()
  [network] = epd_2d_ml_model.create([network],
                                     cfg.epd_2d,
                                     name=name + '/epd_2d')

  return [network]


# Code to call to create the model.
def create_model(input_shape: tf.TensorShape,
                 label_shape: tf.TensorShape,
                 *,
                 optimizer: Any = 'rmsprop',
                 loss: Any = None,
                 metrics: Any = None,
                 run_eagerly: bool = False) -> tf.keras.Model:
  """Gets the desired model, compiled and ready for training.

  The model is built with the config returned by get_config().

  Args:
    input_shape: The shape of the input, including the batch dimension.
    label_shape: The shape of the label.  Currently unused; kept so callers
      can pass the shapes of a prepared dataset.
    optimizer: The optimizer to compile the model with (a Keras optimizer
      name or instance).
    loss: The loss function to compile the model with.  Callers should supply
      the loss they intend to train with.
    metrics: The metrics to compile the model with.
    run_eagerly: Passed through as the run_eagerly argument of compile().

  Returns:
    The compiled Keras model.
  """
  del label_shape  # Unused.

  # Create the input, ensuring the input has the correct shape.
  input_layers = []
  the_input = tf.keras.Input(
      shape=input_shape[1:], name='input_image'  # exclude batch dimension
  )
  input_layers.append(the_input)

  output_layers = create_head(
      input_layers, get_config(), input_shape[1:])

  keras_model = tf.keras.Model(input_layers, output_layers)

  keras_model.compile(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      run_eagerly=run_eagerly)
  keras_model.summary()

  return keras_model
